perf(dpa4): opt so3grid (with pt_expt GridProduct wrapping fix)#5552
Conversation
The parameter-free `GridProduct` NativeOP (added for the so3grid
optimization) has no `serialize`/`deserialize` and is not registered via
`register_dpmodel_mapping`. The pt_expt backend auto-wraps every dpmodel
NativeOP sub-component through `_auto_wrap_native_op`, which requires the
op to be serializable (or registered) to build its dynamic torch wrapper;
otherwise it raises:
TypeError: Cannot auto-wrap GridProduct: it must implement
serialize()/deserialize() or be explicitly registered via
register_dpmodel_mapping().
This broke every `Test Python` shard that loads a DPA4 pt_expt model
(e.g. test_get_model_dpa4.py). Add trivial `serialize`/`deserialize`
(no state, mirroring the GridMLP @class/@Version convention) so the op
auto-wraps cleanly.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (5)
📝 WalkthroughWalkthroughPorts ChangesPort GridMLP and refactor grid-op projector contract
Sequence Diagram(s)sequenceDiagram
participant EquivariantFFN
participant S2GridNet
participant BaseGridNet
participant GridOp as GridProduct/GridMLP/GridBranch
participant _to_grid
participant _from_grid
EquivariantFFN->>S2GridNet: forward(left, right, scalar_pair)
S2GridNet->>BaseGridNet: call(left, right, scalar_pair)
BaseGridNet->>BaseGridNet: _apply_grid_op(left, right, scalar_pair)
BaseGridNet->>GridOp: call(left, right, scalar_pair, to_grid=_to_grid, from_grid=_from_grid)
GridOp->>_to_grid: project coefficients to grid space
_to_grid-->>GridOp: grid tensor (channel width inferred from shape)
GridOp->>GridOp: quadratic product / MLP / branch routing in grid space
GridOp->>_from_grid: project grid back to coefficient space
_from_grid-->>GridOp: coeff_out
GridOp-->>BaseGridNet: coeff_out
BaseGridNet-->>S2GridNet: coeff_out
S2GridNet-->>EquivariantFFN: output
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5552 +/- ##
==========================================
+ Coverage 82.21% 82.24% +0.02%
==========================================
Files 892 894 +2
Lines 101532 102084 +552
Branches 4240 4276 +36
==========================================
+ Hits 83475 83955 +480
- Misses 16753 16828 +75
+ Partials 1304 1301 -3 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
… pt) (deepmodeling#5555) ## What Completes the DPA4/SeZM **SO3 grid projection** port to the dpmodel backend so it faithfully **mirrors master's current pt** `sezm_nn/grid_net.py`. After this, the flagship `examples/water/dpa4/input.json` (which sets `ffn_so3_grid=true`, `message_node_so3=true`, `grid_mlp`) runs on dpmodel/pt_expt. Builds **on top of** the S2-grid base that deepmodeling#5517/deepmodeling#5552 landed (`GridProduct`/`GridMLP`/`op_type='mlp'`, the `_project_frames` refactor). Master's dpmodel was the S2 (`n_frames==1`) slice with SO3/cross-mode fail-fast guarded; this PR generalizes those ops to frame-aware (`n_frames>1`) + cross-mode and adds the missing SO3 pieces — matching current pt exactly (single source of truth: dpmodel == pt). Supersedes deepmodeling#5547 (which ported the *pre*-deepmodeling#5552 design and went structurally stale). ## Changes (all mirror current pt) - **`grid_net.py`**: add `_project_frames`; generalize `GridMLP`/`GridBranch` to frame-aware (`n_frames`); generalize `BaseGridNet` (un-guard `mode='cross'`, `layout='flat'`, `residual_scale_init`, `n_frames>1`; frame-axis to/from-grid via `xp.matmul`+reshape); add `FrameContract`/`FrameExpand`/`_build_frame_degree_index`; add `SO3GridNet` (self+cross). - **`projection.py`**: add `SO3GridProjector` (Wigner-D quadrature) + `resolve_so3_grid`/`_build_so3_frame_set`. - **`ffn.py`**: un-guard `ffn_so3_grid` → `SO3GridNet(mode='self')`. - **`so2.py`**: un-guard `node_wise_{s2,so3}`/`message_node_{s2,so3}` → cross-mode grid products, applied in `call` + round-tripped in serialize. ## Validation - Component parity vs pt (weight-copied fp64): `_project_frames`, `GridMLP`/`GridBranch` (incl. S2 byte-identical regression), `BaseGridNet` cross/flat/residual, `FrameContract`/`FrameExpand`, `SO3GridProjector` matrices, `SO3GridNet` self+cross (op_type glu/mlp/branch, kmax 1&2) — all **1e-12**; rotation equivariance **1e-10**. - **fp32** grid-path parity at the computation-in-fp32 budget (actual diffs 1e-6–1e-8 ≪ 1e-4). - Full-descriptor pt→dpmodel via `DescrptDPA4.deserialize(pt.serialize())` on the example config (lmax=3, mmax=1) — **~1e-14** — proving `dp convert-backend` schema interop. - Permutation-invariance + masked-edge no-op. - Cross-backend consistency rows (pt vs dpmodel **and pt_expt**, mixed_types) for ffn_so3_grid / message_node_so3 / both / grid_mlp. - **Verified on remote GPU (Tesla T4):** 617 (grid+parity+pt_expt) + 50 (consistency) pass, no CUDA device errors. pt_expt forward works today via auto-wrap (consistency + descriptor trio green) — no explicit registration needed. ## Known limitations - pt_expt **training** Parameter-promotion for the new weight-bearing grid classes, `torch.export`/AOTI grid coverage, training e2e, argcheck `doc_only_pt_supported` removal, and freeze/DeepEval are a **follow-up PR**. - `grid_method='e3nn'` (non-Lebedev product grid) stays fail-fast (Lebedev-only, per parent design). - fp32 grid paths use a ~1e-4 budget by design; fp64 is the parity reference. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added SO(3) grid projection support and frame-aware grid networks for DPA4 descriptors, including SO(3)-based FFN and improved cross-mode grid-product wiring. * Extended grid modules to support multi-frame configurations with per-degree frame mixing, and added full SO(3) projector/network serialization. * **Bug Fixes** * Enabled previously disabled/unsupported DPA4 SO(2) convolution cross-mode SO(3)/S2 grid products. * **Documentation** * Updated DPA4 porting-layer documentation to clarify supported configuration flags. * **Tests** * Added/expanded parity, equivariance, serialization/roundtrip, and torch-namespace compatibility tests for the new SO(3) and frame-aware paths. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Based on #5517 (
perf(dpa4): opt so3gridby @OutisLi) — this branch contains all of its commits plus one fix commit that addresses the CI failures on that PR.Problem
#5517 introduces a new parameter-free
GridProductNativeOPindeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pyfor the so3grid optimization, but it has noserialize/deserializeand is not registered viaregister_dpmodel_mapping. The pt_expt backend auto-wraps every dpmodelNativeOPsub-component through_auto_wrap_native_op, which requires the op to be serializable (or registered) to build its dynamic torch wrapper. Otherwise it raises:This broke every
Test Pythonshard that loads a DPA4 pt_expt model (e.g.source/tests/pt_expt/model/test_get_model_dpa4.py::TestGetModelDPA4::test_pair_exclude_types_from_descriptor) on #5517.Fix
Add trivial
serialize/deserializetoGridProduct(no state — mirrors theGridMLP@class/@versionconvention)._auto_wrap_native_opthen passes itshasattr(value, "serialize")guard and returnswrapped_cls.deserialize(value.serialize())cleanly.Notes
GridMLP(also new in perf(dpa4): opt so3grid #5517) already implementsserialize/deserialize; only the parameter-freeGridProductwas missing them._auto_wrap_native_opcode path (deepmd/pt_expt/common.py:138-170); the actual pt_expt DPA4 test runs in CI here.Summary by CodeRabbit
Release Notes
Refactor
Documentation
Tests